{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import absolute_import, division, print_function, unicode_literals\n",
    "\n",
    "import tensorflow as tf\n",
    "\n",
    "from tensorflow.keras.layers import Dense, Flatten, Conv2D, BatchNormalization, Dropout\n",
    "from tensorflow.keras import Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "config = tf.compat.v1.ConfigProto()\n",
    "config.gpu_options.allow_growth = True\n",
    "sess = tf.compat.v1.Session(config=config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "mnist = tf.keras.datasets.mnist\n",
    "\n",
    "(x_train, y_train), (x_test, y_test) = mnist.load_data()\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "x_train = np.concatenate((x_train, x_test), axis=0)\n",
    "y_train = np.concatenate((y_train, y_test), axis=0)\n",
    "\n",
    "x_train, x_test = x_train / 255.0, x_test / 255.0\n",
    "x_test = x_train\n",
    "y_test = y_train\n",
    "\n",
    "# Add a channels dimension\n",
    "x_train = x_train[..., tf.newaxis]\n",
    "x_test = x_test[..., tf.newaxis]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_ds = tf.data.Dataset.from_tensor_slices(\n",
    "    (x_train, y_train)).shuffle(10000).batch(64)\n",
    "\n",
    "test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MyModel(Model):\n",
    "  def __init__(self):\n",
    "    super(MyModel, self).__init__()\n",
    "    self.conv1 = Conv2D(32,kernel_size=3,activation='relu',input_shape=(28,28,1))\n",
    "    self.bn1 = BatchNormalization()\n",
    "    self.conv2 = Conv2D(32,kernel_size=3,activation='relu')\n",
    "    self.bn2 = BatchNormalization()\n",
    "    self.conv3 = Conv2D(32,kernel_size=5,strides=2,padding='same',activation='relu')\n",
    "    self.bn3 = BatchNormalization()\n",
    "    self.drop1 = Dropout(0.5)\n",
    "    \n",
    "    self.conv4 = Conv2D(64,kernel_size=3,activation='relu')\n",
    "    self.bn4 = BatchNormalization()\n",
    "    self.conv5 = Conv2D(64,kernel_size=3,activation='relu')\n",
    "    self.bn5 = BatchNormalization()\n",
    "    self.conv6 = Conv2D(64,kernel_size=5,strides=2,padding='same',activation='relu')\n",
    "    self.bn6 = BatchNormalization()\n",
    "    self.drop2 = Dropout(0.5)\n",
    "    \n",
    "    self.flatten = Flatten()\n",
    "    self.d1 = Dense(128, activation='relu')\n",
    "    self.bn7 = BatchNormalization()\n",
    "    self.drop3 = Dropout(0.5)\n",
    "    self.d2 = Dense(10, activation='softmax')\n",
    "\n",
    "  def call(self, x):\n",
    "    x = self.conv1(x)\n",
    "    x = self.bn1(x)\n",
    "    x = self.conv2(x)\n",
    "    x = self.bn2(x)\n",
    "    x = self.conv3(x)\n",
    "    x = self.bn3(x)\n",
    "    x = self.drop1(x)\n",
    "    \n",
    "    x = self.conv4(x)\n",
    "    x = self.bn4(x)\n",
    "    x = self.conv5(x)\n",
    "    x = self.bn5(x)\n",
    "    x = self.conv6(x)\n",
    "    x = self.bn6(x)\n",
    "    x = self.drop2(x)\n",
    "    \n",
    "    x = self.flatten(x)\n",
    "    x = self.d1(x)\n",
    "    x = self.bn7(x)\n",
    "    x = self.drop3(x)\n",
    "    return self.d2(x)\n",
    "\n",
    "# Create an instance of the model\n",
    "model = MyModel()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f45271a85c0>"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.load_weights('myweight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "loss_object = tf.keras.losses.SparseCategoricalCrossentropy()\n",
    "\n",
    "initial_learning_rate = 0.0005\n",
    "lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(\n",
    "    initial_learning_rate,\n",
    "    decay_steps=100000,\n",
    "    decay_rate=0.96,\n",
    "    staircase=True)\n",
    "\n",
    "optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_loss = tf.keras.metrics.Mean(name='train_loss')\n",
    "train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')\n",
    "\n",
    "test_loss = tf.keras.metrics.Mean(name='test_loss')\n",
    "test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "@tf.function\n",
    "def train_step(images, labels):\n",
    "  with tf.GradientTape() as tape:\n",
    "    predictions = model(images)\n",
    "    loss = loss_object(labels, predictions)\n",
    "  gradients = tape.gradient(loss, model.trainable_variables)\n",
    "  optimizer.apply_gradients(zip(gradients, model.trainable_variables))\n",
    "\n",
    "  train_loss(loss)\n",
    "  train_accuracy(labels, predictions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "@tf.function\n",
    "def test_step(images, labels):\n",
    "  predictions = model(images)\n",
    "  t_loss = loss_object(labels, predictions)\n",
    "\n",
    "  test_loss(t_loss)\n",
    "  test_accuracy(labels, predictions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Layer my_model is casting an input tensor from dtype float64 to the layer's dtype of float32, which is new behavior in TensorFlow 2.  The layer has dtype float32 because it's dtype defaults to floatx.\n",
      "\n",
      "If you intended to run this layer in float32, you can safely ignore this warning. If in doubt, this warning is likely only an issue if you are porting a TensorFlow 1.X model to TensorFlow 2.\n",
      "\n",
      "To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.\n",
      "\n",
      "Epoch 1, Loss: 0.0, Accuracy: 0.0, Test Loss: 1.2565988072310574e-07, Test Accuracy: 100.0\n",
      "Epoch 2, Loss: 0.0, Accuracy: 0.0, Test Loss: 1.2565988072310574e-07, Test Accuracy: 100.0\n",
      "Epoch 3, Loss: 0.0, Accuracy: 0.0, Test Loss: 1.2565988072310574e-07, Test Accuracy: 100.0\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-11-72cce7085dfb>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      5\u001b[0m \u001b[0;31m#     train_step(images, labels)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      6\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 7\u001b[0;31m   \u001b[0;32mfor\u001b[0m \u001b[0mtest_images\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtest_labels\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtest_ds\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      8\u001b[0m     \u001b[0mtest_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtest_images\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtest_labels\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      9\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/data/ops/iterator_ops.py\u001b[0m in \u001b[0;36m__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    620\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    621\u001b[0m   \u001b[0;32mdef\u001b[0m \u001b[0m__next__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m  \u001b[0;31m# For Python 3 compatibility\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 622\u001b[0;31m     \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    623\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    624\u001b[0m   \u001b[0;32mdef\u001b[0m \u001b[0m_next_internal\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/data/ops/iterator_ops.py\u001b[0m in \u001b[0;36mnext\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    664\u001b[0m     \u001b[0;34m\"\"\"Returns a nested structure of `Tensor`s containing the next element.\"\"\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    665\u001b[0m     \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 666\u001b[0;31m       \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_next_internal\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    667\u001b[0m     \u001b[0;32mexcept\u001b[0m \u001b[0merrors\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mOutOfRangeError\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    668\u001b[0m       \u001b[0;32mraise\u001b[0m \u001b[0mStopIteration\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/data/ops/iterator_ops.py\u001b[0m in \u001b[0;36m_next_internal\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    649\u001b[0m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_iterator_resource\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    650\u001b[0m             \u001b[0moutput_types\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_flat_output_types\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 651\u001b[0;31m             output_shapes=self._flat_output_shapes)\n\u001b[0m\u001b[1;32m    652\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    653\u001b[0m       \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/ops/gen_dataset_ops.py\u001b[0m in \u001b[0;36miterator_get_next_sync\u001b[0;34m(iterator, output_types, output_shapes, name)\u001b[0m\n\u001b[1;32m   2657\u001b[0m         \u001b[0m_ctx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_context_handle\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_ctx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_thread_local_data\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice_name\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2658\u001b[0m         \u001b[0;34m\"IteratorGetNextSync\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_ctx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_post_execution_callbacks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0miterator\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2659\u001b[0;31m         \"output_types\", output_types, \"output_shapes\", output_shapes)\n\u001b[0m\u001b[1;32m   2660\u001b[0m       \u001b[0;32mreturn\u001b[0m \u001b[0m_result\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2661\u001b[0m     \u001b[0;32mexcept\u001b[0m \u001b[0m_core\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_FallbackException\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "EPOCHS = 10\n",
    "\n",
    "for epoch in range(EPOCHS):\n",
    "  for images, labels in train_ds:\n",
    "    train_step(images, labels)\n",
    "\n",
    "  for test_images, test_labels in test_ds:\n",
    "    test_step(test_images, test_labels)\n",
    "\n",
    "  template = 'Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}'\n",
    "  print(template.format(epoch+1,\n",
    "                        train_loss.result(),\n",
    "                        train_accuracy.result()*100,\n",
    "                        test_loss.result(),\n",
    "                        test_accuracy.result()*100))\n",
    "\n",
    "  # Reset the metrics for the next epoch\n",
    "  train_loss.reset_states()\n",
    "  train_accuracy.reset_states()\n",
    "  test_loss.reset_states()\n",
    "  test_accuracy.reset_states()\n",
    "#   if "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.save_weights('myweight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
